大部分內容都講完哩,今天來講Q訓練部分來做收尾~
超過observe,就會開始訓練模型。
if t > OBSERVE:
minibatch = random.sample(D, BATCH) # 從replay memory隨機抽取資料
inputs = np.zeros((BATCH, s_t.shape[1], s_t.shape[2], s_t.shape[3])) #batch_size, 80, 80, 4
targets = np.zeros((inputs.shape[0], ACTIONS)) # batch_size, actions
根據batch大小,逐一實作Q現實。
for i in range(0, len(minibatch)):
state_t = minibatch[i][0] # 圖像
action_t = minibatch[i][1] # 執行的動作index
reward_t = minibatch[i][2] # 輸入action後output的reward
state_t1 = minibatch[i][3] # 輸入action後output的下個state
terminal = minibatch[i][4] # 輸入action後output回報遊戲有無結束
inputs[i:i + 1] = state_t # state塞回train要用的array
targets[i] = model.predict(state_t) # Q估計
Q_sa = model.predict(state_t1) # 下一步的Q估計
if terminal:
targets[i, action_t] = reward_t # 如果是最後步,這邊就只有reward
else:
targets[i, action_t] = reward_t + GAMMA * np.max(Q_sa) # Q現實
loss += model.train_on_batch(inputs, targets)
loss_df.loc[len(loss_df)] = loss # 紀錄loss
q_values_df.loc[len(q_values_df)] = np.max(Q_sa) # 紀錄q_value
s_t = initial_state if terminal else s_t1
t = t + 1
如果訓練次數能被1000整除,則遊戲暫停,儲存資料。會直接暫停是因為資料IO要時間,怕影響到採樣。
if t % 1000 == 0:
print("Now we save model")
game_state._game.pause() #pause game while saving to filesystem
model.save_weights("model.h5", overwrite=True)
save_obj(D,"D") #saving episodes
save_obj(t,"time") #caching time steps
save_obj(epsilon,"epsilon") #cache epsilon to avoid repeated randomness in actions
loss_df.to_csv("./objects/loss_df.csv",index=False)
scores_df.to_csv("./objects/scores_df.csv",index=False)
actions_df.to_csv("./objects/actions_df.csv",index=False)
q_values_df.to_csv(q_value_file_path,index=False)
with open("model.json", "w") as outfile:
json.dump(model.to_json(), outfile)
game_state._game.resume()
接下來把之前學的都寫進主程序,就可以開始訓練拉!
def playGame(observe=False):
game = Game()
dino = DinoAgent(game)
game_state = Game_sate(dino,game)
model = buildmodel()
try:
trainNetwork(model,game_state,observe=observe)
except StopIteration:
game.end()
playGame(observe=False)
訓練小恐龍主程序unit5_dino
小恐龍跳跳跳筆者是train了3天至1400多分,專案的原作者的demo則跑到4000多分,有個地方令我感到有些困擾,個人覺得在採樣的時候,樣本跟樣本時間間隔不一致,筆者猜因這個關係而導致收斂跟效果都沒想像中的好,畢竟你看小恐龍玩的規則很單純,但真的要走到高分的次數其實很少,就我觀察小恐龍有時會跳跳躍採到仙人掌或直接撞上,最好的表現是可以在快接近的時候跳躍。我的想法是小恐龍原本可收斂的很快,但因為採樣時間的不均一,導致效果差。
延伸下來我們可以思考幾個方式,例如兩隻程序一個是小恐龍的主程序走,負責採樣,另一支則輸出動作跟訓練,這當然會延伸其他問題例如兩支程序的速度不同,如何互相配合?類神經的速度如果能跟得上環境那還好說,但實際看起來卻並非如此。還有個最暴力方式就是不斷讓環境該停止就停,以此控制採樣跟執行action的間隔,兩個都有些想法,不過這可能要等之後有段時間再來實行了><
用11篇講解keras實作DQN,學到這邊同學有沒很有成就感呢?到這再玩其他強化學習專案就就可以很快上手囉!恭喜同學堅持到今天~接下來幾篇我們會講解進階的DQN方法以及自建環境,大家明天見拉~